#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

import solver
from solver import mk_smt_expr, to_smt_expr, smt_expr
import check
from check import restr_others, loops_to_split, ProofNode
from rep_graph import (mk_graph_slice, vc_num, vc_offs, vc_upto,
	vc_double_range, VisitCount, vc_offset_upto)
import rep_graph
from syntax import (mk_and, mk_cast, mk_implies, mk_not, mk_uminus, mk_var,
	foldr1, boolT, word32T, word8T, builtinTs, true_term, false_term,
	mk_word32, mk_word8, mk_times, Expr, Type, mk_or, mk_eq, mk_memacc,
	mk_num, mk_minus, mk_plus, mk_less)
import syntax
import logic

from target_objects import trace, printout
import target_objects
import itertools

last_knowledge = [1]

class NoSplit(Exception):
	pass

def get_loop_var_analysis_at (p, n):
	k = ('search_loop_var_analysis', n)
	if k in p.cached_analysis:
		return p.cached_analysis[k]
	for hook in target_objects.hooks ('loop_var_analysis'):
		res = hook (p, n)
		if res != None:
			p.cached_analysis[k] = res
			return res
	var_deps = p.compute_var_dependencies ()
	res = p.get_loop_var_analysis (var_deps, n)
	p.cached_analysis[k] = res
	return res

def get_loop_vars_at (p, n):
	vs = [var for (var, data) in get_loop_var_analysis_at (p, n)
			if data == 'LoopVariable'] + [mk_word32 (0)]
	vs.sort ()
	return vs

default_loop_N = 3

last_proof = [None]

def build_proof (p):
	init_hyps = check.init_point_hyps (p)
	proof = build_proof_rec (default_searcher, p, (), list (init_hyps))

	trace ('Built proof for %s' % p.name)
	printout (repr (proof))
	last_proof[0] = proof

	return proof

def split_sample_set (bound):
	ns = (range (10) + range (10, 20, 2)
		+ range (20, 40, 5) + range (40, 100, 10)
		+ range (100, 1000, 50))
	return [n for n in ns if n < bound]

last_find_split_limit = [0]

def find_split_limit (p, n, restrs, hyps, kind, bound = 51, must_find = True,
		hints = [], use_rep = None):
	tag = p.node_tags[n][0]
	trace ('Finding split limit: %d (%s)' % (n, tag))
	last_find_split_limit[0] = (p, n, restrs, hyps, kind)
	if use_rep == None:
		rep = mk_graph_slice (p, fast = True)
	else:
		rep = use_rep
	check_order = hints + split_sample_set (bound) + [bound]
	# bounds strictly outside this range won't be considered
	bound_range = [0, bound]
	best_bound_found = [None]
	def check (i):
		if i < bound_range[0]:
			return True
		if i > bound_range[1]:
			return False
		restrs2 = restrs + ((n, VisitCount (kind, i)), )
		pc = rep.get_pc ((n, restrs2))
		restrs3 = restr_others (p, restrs2, 2)
		epc = rep.get_pc (('Err', restrs3), tag = tag)
		hyp = mk_implies (mk_not (epc), mk_not (pc))
		res = rep.test_hyp_whyps (hyp, hyps)
		if res:
			trace ('split limit found: %d' % i)
			bound_range[1] = i - 1
			best_bound_found[0] = i
		else:
			bound_range[0] = i + 1
		return res

	map (check, check_order)
	while bound_range[0] <= bound_range[1]:
		split = (bound_range[0] + bound_range[1]) / 2
		check (split)

	bound = best_bound_found[0]
	if bound == None:
		trace ('No split limit found for %d (%s).' % (n, tag))
		if must_find:
			assert not 'split limit found'
	return bound

def get_split_limit (p, n, restrs, hyps, kind, bound = 51,
		must_find = True, est_bound = 1, hints = None):
	k = ('SplitLimit', n, restrs, tuple (hyps), kind)
	if k in p.cached_analysis:
		(lim, prev_bound) = p.cached_analysis[k]
		if lim != None or bound <= prev_bound:
			return lim
	if hints == None:
		hints = [est_bound, est_bound + 1, est_bound + 2]
	res = find_split_limit (p, n, restrs, hyps, kind,
		hints = hints, must_find = must_find, bound = bound)
	p.cached_analysis[k] = (res, bound)
	return res

def init_case_splits (p, hyps, tags = None):
	if 'init_case_splits' in p.cached_analysis:
		return p.cached_analysis['init_case_splits']
	if tags == None:
		tags = p.pairing.tags
	poss = logic.possible_graph_divs (p)
	if len (set ([p.node_tags[n][0] for n in poss])) < 2:
		return None
	rep = rep_graph.mk_graph_slice (p)
	assert all ([p.nodes[n].kind == 'Cond' for n in poss])
	pc_map = logic.dict_list ([(rep.get_pc ((c, ())), c)
		for n in poss for c in p.nodes[n].get_conts ()
		if c not in p.loop_data])
	no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()])
	err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag))
		for tag in p.pairing.tags]
	knowledge = EqSearchKnowledge (rep, hyps + err_pc_hyps, list (pc_map))
	last_knowledge[0] = knowledge
	pc_ids = knowledge.classify_vs ()
	id_n_map = logic.dict_list ([(i, n) for (pc, i) in pc_ids.iteritems ()
		for n in pc_map[pc]])
	tag_div_ns = [[[n for n in ns if p.node_tags[n][0] == t] for t in tags]
		for (i, ns) in id_n_map.iteritems ()]
	split_pairs = [(l_ns[0], r_ns[0]) for (l_ns, r_ns) in tag_div_ns
		if l_ns and r_ns]
	p.cached_analysis['init_case_splits'] = split_pairs
	return split_pairs

case_split_tr = []

def init_proof_case_split (p, restrs, hyps):
	ps = init_case_splits (p, hyps)
	if ps == None:
		return None
	p.cached_analysis.setdefault ('finished_init_case_splits', [])
	fin = p.cached_analysis['finished_init_case_splits']
	known_s = set.union (set (restrs), set (hyps))
	for rs in fin:
		if rs <= known_s:
			return None
	rep = rep_graph.mk_graph_slice (p)
	no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()])
	err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag))
		for tag in p.pairing.tags]
	for (n1, n2) in ps:
		pc = rep.get_pc ((n1, ()))
		if rep.test_hyp_whyps (pc, hyps + err_pc_hyps):
			continue
		if rep.test_hyp_whyps (mk_not (pc), hyps + err_pc_hyps):
			continue
		case_split_tr.append ((n1, restrs, hyps))
		return ('CaseSplit', ((n1, p.node_tags[n1][0]), [n1, n2]))
	fin.append (known_s)
	return None

# TODO: deal with all the code duplication between these two searches
class EqSearchKnowledge:
	def __init__ (self, rep, hyps, vs):
		self.rep = rep
		self.hyps = hyps
		self.v_ids = dict ([(v, 1) for v in vs])
		self.model_trace = []
		self.facts = set ()
		self.premise = foldr1 (mk_and, map (rep.interpret_hyp, hyps))

	def add_model (self, m):
		self.model_trace.append (m)
		update_v_ids_for_model2 (self, self.v_ids, m)

	def hyps_add_model (self, hyps):
		if hyps:
			test_expr = foldr1 (mk_and, hyps)
		else:
			# we want to learn something, either a new model, or
			# that all hyps are true. if there are no hyps,
			# learning they're all true is learning nothing.
			# instead force a model
			test_expr = false_term
		test_expr = mk_implies (self.premise, test_expr)
		m = {}
		(r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
			{}, model = m)
		if r == 'unsat':
			if not hyps:
				trace ('WARNING: EqSearchKnowledge: premise unsat.')
				trace ("  ... learning procedure isn't going to work.")
			for hyp in hyps:
				self.facts.add (hyp)
		else:
			assert r == 'sat', r
			self.add_model (m)

	def classify_vs (self):
		while not self.facts:
			hyps = v_id_eq_hyps (self.v_ids)
			if not hyps:
				break
			self.hyps_add_model (hyps)
		return self.v_ids

def update_v_ids_for_model2 (knowledge, v_ids, m):
	# first update the live variables
	ev = lambda v: eval_model_expr (m, knowledge.rep.solv, v)
	groups = logic.dict_list ([((k, ev (v)), v)
		for (v, k) in v_ids.iteritems ()])
	v_ids.clear ()
	for (i, kt) in enumerate (sorted (groups)):
		for v in groups[kt]:
			v_ids[v] = i

def v_id_eq_hyps (v_ids):
	groups = logic.dict_list ([(k, v) for (v, k) in v_ids.iteritems ()])
	hyps = []
	for vs in groups.itervalues ():
		for v in vs[1:]:
			hyps.append (mk_eq (v, vs[0]))
	return hyps

class SearchKnowledge:
	def __init__ (self, rep, name, restrs, hyps, tags, cand_elts = None):
		self.rep = rep
		self.name = name
		self.restrs = restrs
		self.hyps = hyps
		self.tags = tags
		if cand_elts != None:
			(loop_elts, r_elts) = cand_elts
		else:
			(loop_elts, r_elts) = ([], [])
		(pairs, vs) = init_knowledge_pairs (rep, loop_elts, r_elts)
		self.pairs = pairs
		self.v_ids = vs
		self.model_trace = []
		self.facts = set ()
		self.weak_splits = set ()
		self.premise = syntax.true_term
		self.live_pairs_trace = []

	def add_model (self, m):
		self.model_trace.append (m)
		update_v_ids_for_model (self, self.pairs, self.v_ids, m)

	def hyps_add_model (self, hyps, assert_progress = True):
		if hyps:
			test_expr = foldr1 (mk_and, hyps)
		else:
			# we want to learn something, either a new model, or
			# that all hyps are true. if there are no hyps,
			# learning they're all true is learning nothing.
			# instead force a model
			test_expr = false_term
		test_expr = mk_implies (self.premise, test_expr)
		m = {}
		(r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
			{}, model = m)
		if r == 'unsat':
			if not hyps:
				trace ('WARNING: SearchKnowledge: premise unsat.')
				trace ("  ... learning procedure isn't going to work.")
				return
			if assert_progress:
				assert not (set (hyps) <= self.facts), hyps
			for hyp in hyps:
				self.facts.add (hyp)
		else:
			assert r == 'sat', r
			self.add_model (m)
			if assert_progress:
				assert self.model_trace[-2:-1] != [m]

	def eqs_add_model (self, eqs, assert_progress = True):
		preds = [pred for vpair in eqs
			for pred in expand_var_eqs (self, vpair)
			if pred not in self.facts]

		self.hyps_add_model (preds,
			assert_progress = assert_progress)

	def add_weak_split (self, eqs):
		preds = [pred for vpair in eqs
                        for pred in expand_var_eqs (self, vpair)]
		self.weak_splits.add (tuple (sorted (preds)))

	def is_weak_split (self, eqs):
		preds = [pred for vpair in eqs
                        for pred in expand_var_eqs (self, vpair)]
		return tuple (sorted (preds)) in self.weak_splits

def init_knowledge_pairs (rep, loop_elts, cand_r_loop_elts):
	trace ('Doing search knowledge setup now.')
	v_is = [(i, i_offs, i_step,
		[(v, i, i_offs, i_step) for v in get_loop_vars_at (rep.p, i)])
		for (i, i_offs, i_step) in sorted (loop_elts)]
	l_vtyps = set ([v[0].typ for (_, _, _, vs) in v_is for v in vs])
	v_js = [(j, j_offs, j_step,
		[(v, j, j_offs, j_step) for v in get_loop_vars_at (rep.p, j)
			if v.typ in l_vtyps])
		for (j, j_offs, j_step) in sorted (cand_r_loop_elts)]
	vs = {}
	for (_, _, _, var_vs) in v_is + v_js:
		for v in var_vs:
			vs[v] = (v[0].typ, True)
	pairs = {}
	for (i, i_offs, i_step, i_vs) in v_is:
		for (j, j_offs, j_step, j_vs) in v_js:
			pair = ((i, i_offs, i_step), (j, j_offs, j_step))
			pairs[pair] = (i_vs, j_vs)
	trace ('... done.')
	return (pairs, vs)

def update_v_ids_for_model (knowledge, pairs, vs, m):
	rep = knowledge.rep
	# first update the live variables
	groups = {}
	for v in vs:
		(k, const) = vs[v]
		groups.setdefault (k, [])
		groups[k].append ((v, const))
	k_counter = 1
	vs.clear ()
	for k in groups:
		for (const, xs) in split_group (knowledge, m, groups[k]):
			for x in xs:
				vs[x] = (k_counter, const)
			k_counter += 1
	# then figure out which pairings are still viable
	needed_ks = set ()
	zero = syntax.mk_word32 (0)
	for (pair, data) in pairs.items ():
		if data[0] == 'Failed':
			continue
		(lvs, rvs) = data
		lv_ks = set ([vs[v][0] for v in lvs
			if v[0] == zero or not vs[v][1]])
		rv_ks = set ([vs[v][0] for v in rvs])
		miss_vars = lv_ks - rv_ks
		if miss_vars:
			lv_miss = [v[0] for v in lvs if vs[v][0] in miss_vars]
			pairs[pair] = ('Failed', lv_miss.pop ())
		else:
			needed_ks.update ([vs[v][0] for v in lvs + rvs])
	# then drop any vars which are no longer relevant
	for v in vs.keys ():
		if vs[v][0] not in needed_ks:
			del vs[v]

def get_entry_visits_up_to (rep, head, restrs, hyps):
	"""get the set of nodes visited on the entry path entry
	to the loop, up to and including the head point."""
	k = ('loop_visits_up_to', head, restrs, tuple (hyps))
	if k in rep.p.cached_analysis:
		return rep.p.cached_analysis[k]

	[entry] = get_loop_entry_sites (rep, restrs, hyps, head)
	frontier = set ([entry])
	up_to = set ()
	loop = rep.p.loop_body (head)
	while frontier:
		n = frontier.pop ()
		if n == head:
			continue
		new_conts = [n2 for n2 in rep.p.nodes[n].get_conts ()
			if n2 in loop if n2 not in up_to]
		up_to.update (new_conts)
		frontier.update (new_conts)
	rep.p.cached_analysis[k] = up_to
	return up_to

def get_nth_visit_restrs (rep, restrs, hyps, i, visit_num):
	"""get the nth (visit_num-th) visit to node i, using its loop head
	as a restriction point. tricky because there may be a loop entry point
	that brings us in with the loop head before i, or vice-versa."""
	head = rep.p.loop_id (i)
	if i in get_entry_visits_up_to (rep, head, restrs, hyps):
		# node i is in the set visited on the entry path, so
		# the head is visited no more often than it
		offs = 0
	else:
		# these are visited after the head point on the entry path,
		# so the head point is visited 1 more time than it.
		offs = 1
	return ((head, vc_num (visit_num + offs)), ) + restrs

def get_var_pc_var_list (knowledge, v_i):
	rep = knowledge.rep
	(v_i, i, i_offs, i_step) = v_i
	def get_var (k):
		restrs2 = get_nth_visit_restrs (rep, knowledge.restrs,
				knowledge.hyps, i, k)
		(pc, env) = rep.get_node_pc_env ((i, restrs2))
		return (to_smt_expr (pc, env, rep.solv),
			to_smt_expr (v_i, env, rep.solv))
	return [get_var (i_offs + (k * i_step))
		for k in [0, 1, 2]]

def expand_var_eqs (knowledge, (v_i, v_j)):
	if v_j == 'Const':
		pc_vs = get_var_pc_var_list (knowledge, v_i)
		(_, v0) = pc_vs[0]
		return [mk_implies (pc, mk_eq (v, v0))
			for (pc, v) in pc_vs[1:]]
	# sorting the vars guarantees we generate the same
	# mem eqs each time which is important for the solver
	(v_i, v_j) = sorted ([v_i, v_j])
	pc_vs = zip (get_var_pc_var_list (knowledge, v_i),
		get_var_pc_var_list (knowledge, v_j))
	return [pred for ((pc_i, v_i), (pc_j, v_j)) in pc_vs
		for pred in [mk_eq (pc_i, pc_j),
			mk_implies (pc_i, logic.mk_eq_with_cast (v_i, v_j))]]

word_ops = {'bvadd':lambda x, y: x + y, 'bvsub':lambda x, y: x - y,
	'bvmul':lambda x, y: x * y, 'bvurem':lambda x, y: x % y,
	'bvudiv':lambda x, y: x / y, 'bvand':lambda x, y: x & y,
	'bvor':lambda x, y: x | y, 'bvxor': lambda x, y: x ^ y,
	'bvnot': lambda x: ~ x, 'bvneg': lambda x: - x,
	'bvshl': lambda x, y: x << y, 'bvlshr': lambda x, y: x >> y}

bool_ops = {'=>':lambda x, y: (not x) or y, '=': lambda x, y: x == y,
	'not': lambda x: not x, 'true': lambda: True, 'false': lambda: False}

word_ineq_ops = {'=': (lambda x, y: x == y, 'Unsigned'),
	'bvult': (lambda x, y: x < y, 'Unsigned'),
	'word32-eq': (lambda x, y: x == y, 'Unsigned'),
	'bvule': (lambda x, y: x <= y, 'Unsigned'),
	'bvsle': (lambda x, y: x <= y, 'Signed'),
	'bvslt': (lambda x, y: x < y, 'Signed'),
}

def eval_model (m, s, toplevel = None):
	if s in m:
		return m[s]
	if toplevel == None:
		toplevel = s
	if type (s) == str:
		try:
			result = solver.smt_to_val (s)
		except Exception, e:
			trace ('Error with eval_model')
			trace (toplevel)
			raise e
		return result

	op = s[0]

	if op == 'ite':
		[_, b, x, y] = s
		b = eval_model (m, b, toplevel)
		assert b in [false_term, true_term]
		if b == true_term:
			result = eval_model (m, x, toplevel)
		else:
			result = eval_model (m, y, toplevel)
		m[s] = result
		return result

	xs = [eval_model (m, x, toplevel) for x in s[1:]]

	if op[0] == '_' and op[1] in ['zero_extend', 'sign_extend']:
		[_, ex_kind, n_extend] = op
		n_extend = int (n_extend)
		[x] = xs
		assert x.typ.kind == 'Word' and x.kind == 'Num'
		if ex_kind == 'sign_extend':
			val = get_signed_val (x)
		else:
			val = get_unsigned_val (x)
		result = mk_num (val, x.typ.num + n_extend)
	elif op[0] == '_' and op[1] == 'extract':
		[_, _, n_top, n_bot] = op
		n_top = int (n_top)
		n_bot = int (n_bot)
		[x] = xs
		assert x.typ.kind == 'Word' and x.kind == 'Num'
		length = (n_top - n_bot) + 1
		result = mk_num ((x.val >> n_bot) & ((1 << length) - 1), length)
	elif op[0] == 'store-word32':
		(m, p, v) = xs
		(naming, eqs) = m
		eqs = dict (eqs)
		eqs[p.val] = v.val
		eqs = tuple (sorted (eqs.items ()))
		result = (naming, eqs)
	elif op[0] == 'store-word8':
		(m, p, v) = xs
		p_al = p.val & -4
		shift = (p.val & 3) * 8
		(naming, eqs) = m
		eqs = dict (eqs)
		prev_v = eqs[p_al]
		mask_v = prev_v & (((1 << 32) - 1) ^ (255 << shift))
		new_v = mask_v | ((v.val & 255) << shift)
		eqs[p.val] = new_v
		eqs = tuple (sorted (eqs.items ()))
		result = (naming, eqs)
	elif op[0] == 'load-word32':
		(m, p) = xs
		(naming, eqs) = m
		eqs = dict (eqs)
		result = syntax.mk_word32 (eqs[p.val])
	elif op[0] == 'load-word8':
		(m, p) = xs
		p_al = p.val & -4
		shift = (p.val & 3) * 8
		(naming, eqs) = m
		eqs = dict (eqs)
		v = (eqs[p_al] >> shift) & 255
		result = syntax.mk_word8 (v)
	elif xs and xs[0].typ.kind == 'Word' and op in word_ops:
		for x in xs:
			assert x.kind == 'Num', (s, op, x)
		result = word_ops[op](* [x.val for x in xs])
		result = result & ((1 << xs[0].typ.num) - 1)
		result = Expr ('Num', xs[0].typ, val = result)
	elif xs and xs[0].typ.kind == 'Word' and op in word_ineq_ops:
		(oper, signed) = word_ineq_ops[op]
		if signed == 'Signed':
			result = oper (* map (get_signed_val, xs))
		else:
			assert signed == 'Unsigned'
			result = oper (* [x.val for x in xs])
		result = {True: true_term, False: false_term}[result]
	elif op == 'and':
		result = all ([x == true_term for x in xs])
		result = {True: true_term, False: false_term}[result]
	elif op == 'or':
		result = bool ([x for x in xs if x == true_term])
		result = {True: true_term, False: false_term}[result]
	elif op in bool_ops:
		assert all ([x.typ == boolT for x in xs])
		result = bool_ops[op](* [x == true_term for x in xs])
		result = {True: true_term, False: false_term}[result]
	else:
		assert not 's_expr handled', (s, op)
	m[s] = result
	return result

def get_unsigned_val (x):
	assert x.typ.kind == 'Word'
	assert x.kind == 'Num'
	bits = x.typ.num
	v = x.val & ((1 << bits) - 1)
	return v

def get_signed_val (x):
	assert x.typ.kind == 'Word'
	assert x.kind == 'Num'
	bits = x.typ.num
	v = x.val & ((1 << bits) - 1)
	if v >= (1 << (bits - 1)):
		v = v - (1 << bits)
	return v

def short_array_str (arr):
	items = [('%x: %x' % (p.val * 4, v.val))
		for (p, v) in arr.iteritems ()
		if type (p) != str]
	items.sort ()
	return '{' + ', '.join (items) + '}'

def eval_model_expr (m, solv, v):
	s = solver.smt_expr (v, {}, solv)
	s_x = solver.parse_s_expression (s)

	return eval_model (m, s_x)

def model_equal (m, knowledge, vpair):
	preds = expand_var_eqs (knowledge, vpair)
	for pred in preds:
		x = eval_model_expr (m, knowledge.rep.solv, pred)
		assert x in [syntax.true_term, syntax.false_term]
		if x == syntax.false_term:
			return False
	return True

def get_model_trace (knowledge, m, v):
	rep = knowledge.rep
	pc_vs = get_var_pc_var_list (knowledge, v)
	trace = []
	for (pc, v) in pc_vs:
		x = eval_model_expr (m, rep.solv, pc)
		assert x in [syntax.true_term, syntax.false_term]
		if x == syntax.false_term:
			trace.append (None)
		else:
			trace.append (eval_model_expr (m, rep.solv, v))
	return tuple (trace)

def split_group (knowledge, m, group):
	group = list (set (group))
	if group[0][0][0].typ == syntax.builtinTs['Mem']:
		bins = []
		for (v, const) in group:
			for i in range (len (bins)):
				if model_equal (m, knowledge,
						(v, bins[i][1][0])):
					bins[i][1].append (v)
					break
			else:
				if const:
					const = model_equal (m, knowledge,
						(v, 'Const'))
				bins.append ((const, [v]))
		return bins
	else:
		bins = {}
		for (v, const) in group:
			trace = get_model_trace (knowledge, m, v)
			if trace not in bins:
				tconst = len (set (trace) - set ([None])) <= 1
				bins[trace] = (const and tconst, [])
			bins[trace][1].append (v)
		return bins.values ()

def mk_pairing_v_eqs (knowledge, pair, endorsed = True):
	v_eqs = []
	(lvs, rvs) = knowledge.pairs[pair]
	zero = mk_word32 (0)
	for v_i in lvs:
		(k, const) = knowledge.v_ids[v_i]
		if const and v_i[0] != zero:
			if not endorsed or eq_known (knowledge, (v_i, 'Const')):
				v_eqs.append ((v_i, 'Const'))
				continue
		vs_j = [v_j for v_j in rvs if knowledge.v_ids[v_j][0] == k]
		if endorsed:
			vs_j = [v_j for v_j in vs_j
				if eq_known (knowledge, (v_i, v_j))]
		if not vs_j:
			return None
		v_j = vs_j[0]
		v_eqs.append ((v_i, v_j))
	return v_eqs

def eq_known (knowledge, vpair):
	preds = expand_var_eqs (knowledge, vpair)
	return set (preds) <= knowledge.facts

def find_split_loop (p, head, restrs, hyps, unfold_limit = 9,
		node_restrs = None, trace_ind_fails = None):
	assert p.loop_data[head][0] == 'Head'
	assert p.node_tags[head][0] == p.pairing.tags[0]

	# the idea is to loop through testable hyps, starting with ones that
	# need smaller models (the most unfolded models will time out for
	# large problems like finaliseSlot)

	rep = mk_graph_slice (p, fast = True)

	nec = get_necessary_split_opts (p, head, restrs, hyps)
	if nec and nec[0] in ['CaseSplit', 'LoopUnroll']:
		return nec
	elif nec:
		i_j_opts = nec
	else:
		i_j_opts = default_i_j_opts (unfold_limit)

	if trace_ind_fails == None:
		ind_fails = []
	else:
		ind_fails = trace_ind_fails
	for (i_opts, j_opts) in i_j_opts:
		result = find_split (rep, head, restrs, hyps,
			i_opts, j_opts, node_restrs = node_restrs)
		if result[0] != None:
			return result
		ind_fails.extend (result[1])

	if ind_fails:
		trace ('Warning: inductive failures: %s' % ind_fails)
	raise NoSplit ()

def default_i_j_opts (unfold_limit = 9):
	return mk_i_j_opts (unfold_limit = unfold_limit)

def mk_i_j_opts (i_seq_opts = None, j_seq_opts = None, unfold_limit = 9):
	if i_seq_opts == None:
		i_seq_opts = [(0, 1), (1, 1), (2, 1), (3, 1)]
	if j_seq_opts == None:
		j_seq_opts = [(0, 1), (0, 2), (1, 1), (1, 2),
			(2, 1), (2, 2), (3, 1)]
	all_opts = set (i_seq_opts + j_seq_opts)

	def filt (opts, lim):
		return [(start, step) for (start, step) in opts
			if start + (2 * step) + 1 <= lim]

	lims = [(filt (i_seq_opts, lim), filt (j_seq_opts, lim))
		for lim in range (unfold_limit)
		if [1 for (start, step) in all_opts
			if start + (2 * step) + 1 == lim]]
	lims = [(i_opts, j_opts) for (i_opts, j_opts) in lims
		if i_opts and j_opts]
	return lims

necessary_split_opts_trace = []

def get_interesting_linear_series_exprs (p, head):
	k = ('interesting_linear_series', head)
	if k in p.cached_analysis:
		return p.cached_analysis[k]
	res = logic.interesting_linear_series_exprs (p, head,
		get_loop_var_analysis_at (p, head))
	p.cached_analysis[k] = res
	return res

def split_opt_test (p, tags = None):
	if not tags:
		tags = p.pairing.tags
	heads = [head for head in init_loops_to_split (p, ())
		if p.node_tags[head][0] == tags[0]]
	hyps = check.init_point_hyps (p)
	return [(head, get_necessary_split_opts (p, head, (), hyps))
		for head in heads]

def interesting_linear_test (p):
	p.do_analysis ()
	for head in p.loop_heads ():
		inter = get_interesting_linear_series_exprs (p, head)
		hooks = target_objects.hooks ('loop_var_analysis')
		n_exprs = [(n, expr, offs) for (n, vs) in inter.iteritems ()
			if not [hook for hook in hooks if hook (p, n) != None]
			for (kind, expr, offs) in vs]
		if n_exprs:
			rep = rep_graph.mk_graph_slice (p)
		for (n, expr, offs) in n_exprs:
			restrs = tuple ([(n2, vc) for (n2, vc)
				in restr_others_both (p, (), 2, 2)
				if p.loop_id (n2) != p.loop_id (head)])
			vis1 = (n, ((head, vc_offs (1)), ) + restrs)
			vis2 = (n, ((head, vc_offs (2)), ) + restrs)
			pc = rep.get_pc (vis2)
			imp = mk_implies (pc, mk_eq (rep.to_smt_expr (expr, vis2),
				rep.to_smt_expr (mk_plus (expr, offs), vis1)))
			assert rep.test_hyp_whyps (imp, [])
	return True

last_necessary_split_opts = [0]

def get_necessary_split_opts (p, head, restrs, hyps, tags = None):
	if not tags:
		tags = p.pairing.tags
	[l_tag, r_tag] = tags
	last_necessary_split_opts[0] = (p, head, restrs, hyps, tags)

	rep = rep_graph.mk_graph_slice (p, fast = True)
	entries = get_loop_entry_sites (rep, restrs, hyps, head)
	if len (entries) > 1:
		return ('CaseSplit', ((entries[0], tags[0]), [entries[0]]))
	for n in init_loops_to_split (p, restrs):
		if p.node_tags[n][0] != r_tag:
			continue
		entries = get_loop_entry_sites (rep, restrs, hyps, n)
		if len (entries) > 1:
			return ('CaseSplit', ((entries[0], r_tag),
				[entries[0]]))

	stuff = linear_setup_stuff (rep, head, restrs, hyps, tags)
	if stuff == None:
		return None
	seq_eqs = get_matching_linear_seqs (rep, head, restrs, hyps, tags)

	vis = stuff['vis']
	for v in seq_eqs:
		if v[0] == 'LoopUnroll':
			(_, n, est_bound) = v
			lim = get_split_limit (p, n, restrs, hyps, 'Number',
				est_bound = est_bound, must_find = False)
			if lim != None:
				return ('LoopUnroll', n)
			continue
		((n, expr), (n2, expr2), (l_start, l_step), (r_start, r_step),
			_, _) = v
		eqs = [rep_graph.eq_hyp ((expr,
			(vis (n, l_start + (i * l_step)), l_tag)),
			(expr2, (vis (n2, r_start + (i * r_step)), r_tag)))
			for i in range (2)]
		vis_hyp = rep_graph.pc_true_hyp ((vis (n, l_start), l_tag))
		vis_hyps = [vis_hyp] + stuff['hyps']
		eq = foldr1 (mk_and, map (rep.interpret_hyp, eqs))
		m = {}
		if rep.test_hyp_whyps (eq, vis_hyps, model = m):
			trace ('found necessary split info: (%s, %s), (%s, %s)'
				% (l_start, l_step, r_start, r_step))
			return mk_i_j_opts ([(l_start + i, l_step)
					for i in range (r_step + 1)],
				[(r_start + i, r_step)
					for i in range (l_step + 1)],
				unfold_limit = 100)
		n_vcs = entry_path_no_loops (rep, l_tag, m, head)
		path_hyps = [rep_graph.pc_true_hyp ((n_vc, l_tag)) for n_vc in n_vcs]
		if rep.test_hyp_whyps (eq, stuff['hyps'] + path_hyps):
			# immediate case split on difference between entry paths
			checks = [(stuff['hyps'], eq_hyp, 'eq')
				for eq_hyp in eqs]
			return derive_case_split (rep, n_vcs, checks)
		necessary_split_opts_trace.append ((n, expr, (l_start, l_step),
			(r_start, r_step), 'Seq check failed'))
	return None

def linear_setup_stuff (rep, head, restrs, hyps, tags):
	[l_tag, r_tag] = tags
	k = ('linear_seq setup', head, restrs, tuple (hyps), tuple (tags))
	p = rep.p
	if k in p.cached_analysis:
		return p.cached_analysis[k]

	assert p.node_tags[head][0] == l_tag
	l_seq_vs = get_interesting_linear_series_exprs (p, head)
	if not l_seq_vs:
		return None
	r_seq_vs = {}
	restr_env = {p.loop_id (head): restrs}
	for n in init_loops_to_split (p, restrs):
		if p.node_tags[n][0] != r_tag:
			continue
		vs = get_interesting_linear_series_exprs (p, n)
		r_seq_vs.update (vs)
	if not r_seq_vs:
		return None

	def vis (n, i):
		restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, i)
		return (n, restrs2)
	smt = lambda expr, n, i: rep.to_smt_expr (expr, vis (n, i))
	smt_pc = lambda n, i: rep.get_pc (vis (n, i))

	# remove duplicates by concretising
	l_seq_vs = dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset))
		for n in l_seq_vs
		for (kind, expr, offs, oset) in l_seq_vs[n]]).values ()
	r_seq_vs = dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset))
                for n in r_seq_vs
		for (kind, expr, offs, oset) in r_seq_vs[n]]).values ()

	hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), r_tag))
		for n in set ([n for (_, n, _, _, _) in r_seq_vs])]
	hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), l_tag))
		for n in set ([n for (_, n, _, _, _) in l_seq_vs])]
	hyps = hyps + [check.non_r_err_pc_hyp (tags,
			restr_others (p, restrs, 2))]

	r = {'l_seq_vs': l_seq_vs, 'r_seq_vs': r_seq_vs,
		'hyps': hyps, 'vis': vis, 'smt': smt, 'smt_pc': smt_pc}
	p.cached_analysis[k] = r
	return r

def get_matching_linear_seqs (rep, head, restrs, hyps, tags):
	k = ('matching linear seqs', head, restrs, tuple (hyps), tuple (tags))
	p = rep.p
	if k in p.cached_analysis:
		v = p.cached_analysis[k]
		(x, y) = itertools.tee (v[0])
		v[0] = x
		return y

	[l_tag, r_tag] = tags
	stuff = linear_setup_stuff (rep, head, restrs, hyps, tags)
	if stuff == None:
		return []

	hyps = stuff['hyps']
	vis = stuff['vis']

	def get_model (n, offs):
		m = {}
		offs_smt = stuff['smt'] (offs, n, 1)
		eq = mk_eq (mk_times (offs_smt, mk_num (4, offs_smt.typ)),
			mk_num (0, offs_smt.typ))
		ex_hyps = [rep_graph.pc_true_hyp ((vis (n, 1), l_tag)),
			rep_graph.pc_true_hyp ((vis (n, 2), l_tag))]
		res = rep.test_hyp_whyps (eq, hyps + ex_hyps, model = m)
		if not m:
			necessary_split_opts_trace.append ((n, kind, 'NoModel'))
			return None
		return m

	r = (seq_eq
		for (kind, n, expr, offs, oset) in sorted (stuff['l_seq_vs'])
		if [v for v in stuff['r_seq_vs'] if v[0] == kind]
		for m in [get_model (n, offs)]
		if m
		for seq_eq in [get_linear_seq_eq (rep, m, stuff,
					(kind, n, expr, offs, oset)),
			get_model_r_side_unroll (rep, tags, m,
				restrs, hyps, stuff)]
		if seq_eq != None)
	(x, y) = itertools.tee (r)
	p.cached_analysis[k] = [y]
	return x

def get_linear_seq_eq (rep, m, stuff, expr_t1):
	def get_int_min (expr):
		v = eval_model_expr (m, rep.solv, expr)
		assert v.kind == 'Num', v
		vs = [v.val + (i << v.typ.num) for i in range (-2, 3)]
		(_, v) = min ([(abs (v), v) for v in vs])
		return v
	(kind, n1, expr1, offs1, oset1) = expr_t1
	smt = stuff['smt']
	expr_init = smt (expr1, n1, 0)
	expr_v = get_int_min (expr_init)
	offs_v = get_int_min (smt (offs1, n1, 1))
	r_seqs = [(n, expr, offs, oset2,
			get_int_min (mk_minus (expr_init, smt (expr, n, 0))),
			get_int_min (smt (offs, n, 0)))
		for (kind2, n, expr, offs, oset2) in sorted (stuff['r_seq_vs'])
		if kind2 == kind]

	for (n, expr, offs2, oset2, diff, offs_v2) in sorted (r_seqs):
		mult = offs_v / offs_v2
		if offs_v % offs_v2 != 0 or mult > 8:
			necessary_split_opts_trace.append ((n, expr,
				'StepWrong', offs_v, offs_v2))
		elif diff % offs_v2 != 0 or (diff * offs_v2) < 0 or (diff / offs_v2) > 8:
			necessary_split_opts_trace.append ((n, expr,
				'StartWrong', diff, offs_v2))
		else:
			return ((n1, expr1), (n, expr), (0, 1),
				(diff / offs_v2, mult), (offs1, offs2),
				(oset1, oset2))
	return None

last_r_side_unroll = [None]

def get_model_r_side_unroll (rep, tags, m, restrs, hyps, stuff):
	p = rep.p
	[l_tag, r_tag] = tags
	last_r_side_unroll[0] = (rep, tags, m, restrs, hyps, stuff)

	r_kinds = set ([kind for (kind, n, _, _, _) in stuff['r_seq_vs']])
	l_visited_ns_vcs = logic.dict_list ([(n, vc)
		for (tag, n, vc) in rep.node_pc_env_order
		if tag == l_tag
		if eval_pc (rep, m, (n, vc))])
	l_arc_interesting = [(n, vc, kind, expr)
		for (n, vcs) in l_visited_ns_vcs.iteritems ()
		if len (vcs) == 1
		for vc in vcs
		for (kind, expr)
			in logic.interesting_node_exprs (p, n, tags = tags)
		if kind in r_kinds
		if expr.typ.kind == 'Word']
	l_kinds = set ([kind for (n, vc, kind, _) in l_arc_interesting])

	# FIXME: cloned
	def canon_n (n, typ):
		vs = [n + (i << typ.num) for i in range (-2, 3)]
		(_, v) = min ([(abs (v), v) for v in vs])
		return v
	def get_int_min (expr):
		v = eval_model_expr (m, rep.solv, expr)
		assert v.kind == 'Num', v
		return canon_n (v.val, v.typ)
	def eval (expr, n, vc):
		expr = rep.to_smt_expr (expr, (n, vc))
		return get_int_min (expr)

	val_interesting_map = logic.dict_list ([((kind, eval (expr, n, vc)), n)
		for (n, vc, kind, expr) in l_arc_interesting])

	smt = stuff['smt']

	for (kind, n, expr, offs, _) in stuff['r_seq_vs']:
		if kind not in l_kinds:
			continue
		if expr.typ.kind != 'Word':
			continue
		expr_n = get_int_min (smt (expr, n, 0))
		offs_n = get_int_min (smt (offs, n, 0))
		hit = ([i for i in range (64)
			if (kind, canon_n (expr_n + (offs_n * i), expr.typ))
				in val_interesting_map])
		if [i for i in hit if i > 4]:
			return ('LoopUnroll', p.loop_id (n), max (hit))
	return None

last_failed_pairings = []

def setup_split_search (rep, head, restrs, hyps,
		i_opts, j_opts, unfold_limit = None, tags = None,
		node_restrs = None):
	p = rep.p

	if not tags:
		tags = p.pairing.tags
	if node_restrs == None:
		node_restrs = set (p.nodes)
	if unfold_limit == None:
		unfold_limit = max ([start + (2 * step) + 1
			for (start, step) in i_opts + j_opts])

	trace ('Split search at %d, unfold limit %d.' % (head, unfold_limit))

	l_tag, r_tag = tags
	loop_elts = [(n, start, step) for n in p.splittable_points (head)
		if n in node_restrs
		for (start, step) in i_opts]
	init_to_split = init_loops_to_split (p, restrs)
	r_to_split = [n for n in init_to_split if p.node_tags[n][0] == r_tag]
	cand_r_loop_elts = [(n2, start, step) for n in r_to_split
		for n2 in p.splittable_points (n)
		if n2 in node_restrs
		for (start, step) in j_opts]

	err_restrs = restr_others (p, tuple ([(sp, vc_upto (unfold_limit))
		for sp in r_to_split]) + restrs, 1)
	nrerr_pc = mk_not (rep.get_pc (('Err', err_restrs), tag = r_tag))

	def get_pc (n, k):
		restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, k)
		return rep.get_pc ((n, restrs2))

	for n in r_to_split:
		get_pc (n, unfold_limit)
	get_pc (head, unfold_limit)

	premise = foldr1 (mk_and, [nrerr_pc] + map (rep.interpret_hyp, hyps))
	premise = logic.weaken_assert (premise)

	knowledge = SearchKnowledge (rep,
		'search at %d (unfold limit %d)' % (head, unfold_limit),
		restrs, hyps, tags, (loop_elts, cand_r_loop_elts))
	knowledge.premise = premise
	last_knowledge[0] = knowledge

	# make sure the representation is in sync
	rep.test_hyp_whyps (true_term, hyps)

	# make sure all mem eqs are being tracked
	mem_vs = [v for v in knowledge.v_ids if v[0].typ == builtinTs['Mem']]
	for (i, v) in enumerate (mem_vs):
		for v2 in mem_vs[:i]:
			for pred in expand_var_eqs (knowledge, (v, v2)):
				smt_expr (pred, {}, rep.solv)
	for v in knowledge.v_ids:
		for pred in expand_var_eqs (knowledge, (v, 'Const')):
			smt_expr (pred, {}, rep.solv)

	return knowledge

def get_loop_entry_sites (rep, restrs, hyps, head):
	k = ('loop_entry_sites', restrs, tuple (hyps), rep.p.loop_id (head))
	if k in rep.p.cached_analysis:
		return rep.p.cached_analysis[k]
	ns = set ([n for n2 in rep.p.loop_body (head)
		for n in rep.p.preds[n2]
		if rep.p.loop_id (n) == None])
	def npc (n):
		return rep_graph.pc_false_hyp (((n, tuple ([(n2, restr)
			for (n2, restr) in restrs if n2 != n])),
				rep.p.node_tags[n][0]))
	res = [n for n in ns if not rep.test_hyp_imp (hyps, npc (n))]
	rep.p.cached_analysis[k] = res
	return res

def rebuild_knowledge (head, knowledge):
	i_opts = sorted (set ([(start, step)
		for ((_, start, step), _) in knowledge.pairs]))
	j_opts = sorted (set ([(start, step)
		for (_, (_, start, step)) in knowledge.pairs]))
	knowledge2 = setup_split_search (knowledge.rep, head, knowledge.restrs,
		knowledge.hyps, i_opts, j_opts)
	knowledge2.facts.update (knowledge.facts)
	for m in knowledge.model_trace:
		knowledge2.add_model (m)
	return knowledge2

def split_search (head, knowledge):
	rep = knowledge.rep
	p = rep.p

	# test any relevant cached solutions.
	p.cached_analysis.setdefault (('v_eqs', head), set ())
	v_eq_cache = p.cached_analysis[('v_eqs', head)]
	for (pair, eqs) in v_eq_cache:
		if pair in knowledge.pairs:
			knowledge.eqs_add_model (list (eqs),
				assert_progress = False)

	while True:
		trace ('In %s' % knowledge.name)
		trace ('Computing live pairings')
		pair_eqs = [(pair, mk_pairing_v_eqs (knowledge, pair))
			for pair in sorted (knowledge.pairs)
			if knowledge.pairs[pair][0] != 'Failed']
		if not pair_eqs:
			ind_fails = trace_search_fail (knowledge)
			return (None, ind_fails)

		endorsed = [(pair, eqs) for (pair, eqs) in pair_eqs
			if eqs != None]
		trace (' ... %d live pairings, %d endorsed' %
			(len (pair_eqs), len (endorsed)))
		knowledge.live_pairs_trace.append (len (pair_eqs))
		for (pair, eqs) in endorsed:
			if knowledge.is_weak_split (eqs):
				trace ('  dropping endorsed - probably weak.')
				knowledge.pairs[pair] = ('Failed',
					'ExpectedSplitWeak', eqs)
				continue
			split = build_and_check_split (p, pair, eqs,
				knowledge.restrs, knowledge.hyps,
				knowledge.tags)
			if split == None:
				knowledge.pairs[pair] = ('Failed',
					'SplitWeak', eqs)
				knowledge.add_weak_split (eqs)
				continue
			elif split == 'InductFailed':
				knowledge.pairs[pair] = ('Failed',
					'InductFailed', eqs)
			elif split[0] == 'SingleRevInduct':
				return split
			else:
				v_eq_cache.add ((pair, tuple (eqs)))
				trace ('Found split!')
				return ('Split', split)
		if endorsed:
			continue

		(pair, _) = pair_eqs[0]
		trace ('Testing guess for pair: %s' % str (pair))
		eqs = mk_pairing_v_eqs (knowledge, pair, endorsed = False)
		assert eqs, pair
		knowledge.eqs_add_model (eqs)

def build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags):
	split = v_eqs_to_split (p, pair, eqs, restrs, hyps, tags = tags)
	if split == None:
		return None
	res = check_split_induct (p, restrs, hyps, split, tags = tags)
	if res:
		return split
	else:
		return 'InductFailed'

def build_and_check_split (p, pair, eqs, restrs, hyps, tags):
	res = build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags)
	if res != 'InductFailed':
		return res

	# induction has failed at this point, but we might be able to rescue
	# it one of two different ways.
	((l_split, _, l_step), _) = pair
	extra = get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step)
	if extra:
		res = build_and_check_split (p, pair, eqs, restrs, hyps, tags)
		# the additional linear eqs get built into the result
		if res != 'InductFailed':
			return res

	(_, (r_split, _, _)) = pair
	r_loop = p.loop_id (r_split)
	spec = get_new_rhs_speculate_ineq (p, restrs, hyps, r_loop)
	if spec:
		hyp = check.single_induct_resulting_hyp (p, restrs, spec)
		hyps2 = hyps + [hyp]
		res = build_and_check_split (p, pair, eqs, restrs, hyps2, tags)
		if res != 'InductFailed':
			return ('SingleRevInduct', spec)
	return 'InductFailed'

def get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step):
	k = ('extra_linear_seq_eqs', l_split, l_step)
	if k in p.cached_analysis:
		return []
	if not [v for (v, data) in get_loop_var_analysis_at (p, l_split)
			if data[0] == 'LoopLinearSeries']:
		return []
	import loop_bounds
	lin_series_eqs = loop_bounds.get_linear_series_eqs (p, l_split,
		restrs, [], omit_standard = True)
	p.cached_analysis[k] = lin_series_eqs
	return lin_series_eqs

def trace_search_fail (knowledge):
	trace (('Exhausted split candidates for %s' % knowledge.name))
	fails = [it for it in knowledge.pairs.items ()
		if it[1][0] == 'Failed']
	last_failed_pairings.append (fails)
	del last_failed_pairings[:-10]
	fails10 = fails[:10]
	trace ('  %d of %d failed pairings:' % (len (fails10),
		len (fails)))
	for f in fails10:
		trace ('    %s' % (f,))
	ind_fails = [it for it in fails
		if str (it[1][1]) == 'InductFailed']
	if ind_fails:
		trace (  'Inductive failures!')
	else:
		trace (  'No inductive failures.')
	for f in ind_fails:
		trace ('    %s' % (f,))
	return ind_fails

def find_split (rep, head, restrs, hyps, i_opts, j_opts,
		unfold_limit = None, tags = None,
		node_restrs = None):
	knowledge = setup_split_search (rep, head, restrs, hyps,
		i_opts, j_opts, unfold_limit = unfold_limit,
		tags = tags, node_restrs = node_restrs)

	res = split_search (head, knowledge)

	if res[0]:
		return res

	(models, facts, n_vcs) = most_common_path (head, knowledge)
	if not n_vcs:
		return res

	[tag, _] = knowledge.tags
	knowledge = setup_split_search (rep, head, restrs,
		hyps + [rep_graph.pc_true_hyp ((n_vc, tag)) for n_vc in n_vcs],
		i_opts, j_opts, unfold_limit, tags, node_restrs = node_restrs)
	knowledge.facts.update (facts)
	for m in models:
		knowledge.add_model (m)
	res = split_search (head, knowledge)

	if res[0] == None:
		return res
	(_, split) = res
	checks = check.split_init_step_checks (rep.p, restrs,
                        hyps, split)

	return derive_case_split (rep, n_vcs, checks)

def most_common_path (head, knowledge):
	rep = knowledge.rep
	[tag, _] = knowledge.tags
	data = logic.dict_list ([(tuple (entry_path_no_loops (rep,
			tag, m, head)), m)
		for m in knowledge.model_trace])
	if len (data) < 2:
		return (None, None, None)

	(_, path) = max ([(len (data[path]), path) for path in data])
	models = data[path]
	facts = knowledge.facts
	other_n_vcs = set.intersection (* [set (path2) for path2 in data
		if path2 != path])

	n_vcs = []
	pcs = set ()
	for n_vc in path:
		if n_vc in other_n_vcs:
			continue
		if rep.p.loop_id (n_vc[0]):
			continue
		pc = rep.get_pc (n_vc)
		if pc not in pcs:
			pcs.add (pc)
			n_vcs.append (n_vc)
	assert n_vcs

	return (models, facts, n_vcs)

def eval_pc (rep, m, n_vc, tag = None):
	hit = eval_model_expr (m, rep.solv, rep.get_pc (n_vc, tag = tag))
	assert hit in [syntax.true_term, syntax.false_term], (n_vc, hit)
	return hit == syntax.true_term

def entry_path (rep, tag, m, head):
	n_vcs = []
	for (tag2, n, vc) in rep.node_pc_env_order:
		if n == head:
			break
		if tag2 != tag:
			continue
		if eval_pc (rep, m, (n, vc), tag):
			n_vcs.append ((n, vc))
	return n_vcs

def entry_path_no_loops (rep, tag, m, head = None):
	n_vcs = entry_path (rep, tag, m, head)
	return [(n, vc) for (n, vc) in n_vcs
		if not rep.p.loop_id (n)]

last_derive_case_split = [0]

def derive_case_split (rep, n_vcs, checks):
	last_derive_case_split[0] = (rep.p, n_vcs, checks)
	# remove duplicate pcs
	n_vcs_uniq = dict ([(rep.get_pc (n_vc), (i, n_vc))
		for (i, n_vc) in enumerate (n_vcs)]).values ()
	n_vcs = [n_vc for (i, n_vc) in sorted (n_vcs_uniq)]
	assert n_vcs
	tag = rep.p.node_tags[n_vcs[0][0]][0]
	keep_n_vcs = []
	test_n_vcs = n_vcs
	mk_thyps = lambda n_vcs: [rep_graph.pc_true_hyp ((n_vc, tag))
		for n_vc in n_vcs]
	while len (test_n_vcs) > 1:
		i = len (test_n_vcs) / 2
		test_in = test_n_vcs[:i]
		test_out = test_n_vcs[i:]
		checks2 = [(hyps + mk_thyps (test_in + keep_n_vcs), hyp, nm)
			for (hyps, hyp, nm) in checks]
		(verdict, _) = check.test_hyp_group (rep, checks2)
		if verdict:
			# forget n_vcs that were tested out
			test_n_vcs = test_in
		else:
			# focus on n_vcs that were tested out
			test_n_vcs = test_out
			keep_n_vcs.extend (test_in)
	[(n, vc)] = test_n_vcs
	return ('CaseSplit', ((n, tag), [n]))

def mk_seq_eqs (p, split, step, with_rodata):
	# eqs take the form of a number of constant expressions
	eqs = []

	# the variable 'loop' will be converted to the point in
	# the sequence - note this should be multiplied by the step size
	loop = mk_var ('%i', word32T)
	if step == 1:
		minus_loop_step = mk_uminus (loop)
	else:
		minus_loop_step = mk_times (loop, mk_word32 (- step))

	for (var, data) in get_loop_var_analysis_at (p, split):
		if data == 'LoopVariable':
			if with_rodata and var.typ == builtinTs['Mem']:
				eqs.append (logic.mk_rodata (var))
		elif data == 'LoopConst':
			if var.typ not in syntax.phantom_types:
				eqs.append (var)
		elif data == 'LoopLeaf':
			continue
		elif data[0] == 'LoopLinearSeries':
			(_, form, _) = data
			eqs.append (form (var,
				mk_cast (minus_loop_step, var.typ)))
		else:
			assert not 'var_deps type understood'

	k = ('extra_linear_seq_eqs', split, step)
	eqs += p.cached_analysis.get (k, [])

	return eqs

def c_memory_loop_invariant (p, c_sp, a_sp):
	def mem_vars (split):
		return [v for (v, data) in get_loop_var_analysis_at (p, split)
			if v.typ == builtinTs['Mem']
			if data == 'LoopVariable']

	if mem_vars (a_sp):
		return []
	# if ASM keeps memory constant through the loop, it is implying this
	# is semantically possible in C also, though it may not be
	# syntactically the case
	# anyway, we have to assert C memory equals *something* inductively
	# so we pick C initial memory.
	return mem_vars (c_sp)

def v_eqs_to_split (p, pair, v_eqs, restrs, hyps, tags = None):
	trace ('v_eqs_to_split: (%s, %s)' % pair)

	((l_n, l_init, l_step), (r_n, r_init, r_step)) = pair
	l_details = (l_n, (l_init, l_step), mk_seq_eqs (p, l_n, l_step, True)
		+ [v_i[0] for (v_i, v_j) in v_eqs if v_j == 'Const'])
	r_details = (r_n, (r_init, r_step), mk_seq_eqs (p, r_n, r_step, False)
		+ c_memory_loop_invariant (p, r_n, l_n))

	eqs = [(v_i[0], mk_cast (v_j[0], v_i[0].typ))
		for (v_i, v_j) in v_eqs if v_j != 'Const'
		if v_i[0] != syntax.mk_word32 (0)]

	n = 2
	split = (l_details, r_details, eqs, n, (n * r_step) - 1)
	trace ('Split: %s' % (split, ))
	if tags == None:
		tags = p.pairing.tags
	hyps = hyps + check.split_loop_hyps (tags, split, restrs, exit = True)

	r_max = get_split_limit (p, r_n, restrs, hyps, 'Offset',
		bound = (n + 2) * r_step, must_find = False,
		hints = [n * r_step, n * r_step + 1])
	if r_max == None:
		trace ('v_eqs_to_split: no RHS limit')
		return None

	if r_max > n * r_step:
		trace ('v_eqs_to_split: RHS limit not %d' % (n * r_step))
		return None
	trace ('v_eqs_to_split: split %s' % (split,))
	return split

def get_n_offset_successes (rep, sp, step, restrs):
	loop = rep.p.loop_body (sp)
	ns = [n for n in loop if rep.p.nodes[n].kind == 'Call']
	succs = []
	for i in range (step):
		for n in ns:
			vc = vc_offs (i + 1)
			if n == sp:
				vc = vc_offs (i)
			n_vc = (n, restrs + tuple ([(sp, vc)]))
			(_, _, succ) = rep.get_func (n_vc)
			pc = rep.get_pc (n_vc)
			succs.append (syntax.mk_implies (pc, succ))
	return succs

eq_ineq_ops = set (['Equals', 'Less', 'LessEquals',
	'SignedLess', 'SignedLessEquals'])

def split_linear_eq (cond):
	if cond.is_op ('Not'):
		[c] = cond.vals
		return split_linear_eq (c)
	elif cond.is_op (eq_ineq_ops):
		return (cond.vals[0], cond.vals[1])
	elif cond.is_op ('PArrayValid'):
		[htd, typ_expr, p, num] = cond.vals
		assert typ_expr.kind == 'Type'
		typ = typ_expr.val
		return split_linear_eq (logic.mk_array_size_ineq (typ, num, p))
	else:
		return None

def possibly_linear_ineq (cond):
	rv = split_linear_eq (cond)
	if not rv:
		return False
	(lhs, rhs) = rv
	return logic.possibly_linear (lhs) and logic.possibly_linear (rhs)

def linear_const_comparison (p, n, cond):
	"""examines a condition. if it is a linear (e.g. Less) comparison
	between a linear series variable and a loop-constant expression,
	return (linear side, const side), or None if not the case."""
	rv = split_linear_eq (cond)
	loop_head = p.loop_id (n)
	if not rv:
		return None
	(lhs, rhs) = rv
	zero = mk_num (0, lhs.typ)
	offs = logic.get_loop_linear_offs (p, loop_head)
	(lhs_offs, rhs_offs) = [offs (n, expr) for expr in [lhs, rhs]]
	oset = set ([lhs_offs, rhs_offs])
	if zero in oset and None not in oset and len (oset) > 1:
		if lhs_offs == zero:
			return (rhs, lhs)
		else:
			return (lhs, rhs)
	return None

def do_linear_rev_test (rep, restrs, hyps, split, eqs_assume, pred, large):
	p = rep.p
	(tag, _) = p.node_tags[split]
	checks = (check.single_loop_rev_induct_checks (p, restrs, hyps, tag,
			split, eqs_assume, pred)
		+ check.single_loop_rev_induct_base_checks (p, restrs, hyps,
			tag, split, large, eqs_assume, pred))

	groups = check.proof_check_groups (checks)
	for group in groups:
		(res, _) = check.test_hyp_group (rep, group)
		if not res:
			return False
	return True

def get_extra_assn_linear_conds (expr):
	if expr.is_op ('And'):
		return [cond for conj in logic.split_conjuncts (expr)
			for cond in get_extra_assn_linear_conds (conj)]
	if not expr.is_op ('Or'):
		return [expr]
	arr_vs = [v for v in expr.vals if v.is_op ('PArrayValid')]
	if not arr_vs:
		return [expr]
	[htd, typ_expr, p, num] = arr_vs[0].vals
	assert typ_expr.kind == 'Type'
	typ = typ_expr.val
	less_eq = logic.mk_array_size_ineq (typ, num, p)
	assn = logic.mk_align_valid_ineq (('Array', typ, num), p)
	return get_extra_assn_linear_conds (assn) + [less_eq]

def get_rhs_speculate_ineq (p, restrs, loop_head):
	assert p.loop_id (loop_head), loop_head
	loop_head = p.loop_id (loop_head)
	restrs = tuple ([(n, vc) for (n, vc) in restrs
		if p.node_tags[n][0] == p.node_tags[loop_head][0]])
	key = ('rhs_speculate_ineq', restrs, loop_head)
	if key in p.cached_analysis:
		return p.cached_analysis[key]

	res = rhs_speculate_ineq (p, restrs, loop_head)
	p.cached_analysis[key] = res
	return res

def get_new_rhs_speculate_ineq (p, restrs, hyps, loop_head):
	res = get_rhs_speculate_ineq (p, restrs, loop_head)
	if res == None:
		return None
	(point, _, (pred, _)) = res
	hs = [h for h in hyps if point in [n for ((n, _), _) in h.visits ()]
		if pred in h.get_vals ()]
	if hs:
		return None
	return res

def rhs_speculate_ineq (p, restrs, loop_head):
	"""code for handling an interesting case in which the compiler
	knows that the RHS program might fail in the future. for instance,
	consider a loop that cannot be exited until iterator i reaches value n.
	any error condition which implies i < b must hold of i - 1, thus
	n <= b.

	detects this case and identifies the inequality n <= b"""
	body = p.loop_body (loop_head)

	# if the loop contains function calls, skip it,
	# otherwise we need to figure out whether they terminate
	if [n for n in body if p.nodes[n].kind == 'Call']:
		return None

	exit_nodes = set ([n for n in body for n2 in p.nodes[n].get_conts ()
			if n2 != 'Err' if n2 not in body])
	assert set ([p.nodes[n].kind for n in exit_nodes]) <= set (['Cond'])

	# if there are multiple exit conditions, too hard for now
	if len (exit_nodes) > 1:
		return None

	[exit_n] = list (exit_nodes)
	rv = linear_const_comparison (p, exit_n, p.nodes[exit_n].cond)
	if not rv:
		return None
	(linear, const) = rv

	err_cond_sites = [(n, p.nodes[n].err_cond ()) for n in body]
	err_conds = set ([(n, cond) for (n, err_cond) in err_cond_sites
		if err_cond
		for assn in logic.split_conjuncts (mk_not (err_cond))
		for cond in get_extra_assn_linear_conds (assn)
		if possibly_linear_ineq (cond)])
	if not err_conds:
		return None

	assert const.typ.kind == 'Word'
	rep = rep_graph.mk_graph_slice (p)
	eqs = mk_seq_eqs (p, exit_n, 1, False)
	import loop_bounds
	eqs += loop_bounds.get_linear_series_eqs (p, exit_n,
                restrs, [], omit_standard = True)

	large = (2 ** const.typ.num) - 3
	const_less = lambda n: mk_less (const, mk_num (n, const.typ))
	less_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n,
		eqs, const_less (n), large)
	const_ge = lambda n: mk_less (mk_num (n, const.typ), const)
	ge_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n,
		eqs, const_ge (n), large)

	res = logic.binary_search_least (less_test, 1, large)
	if res:
		return (loop_head, (eqs, 1), (const_less (res), large))
	res = logic.binary_search_greatest (ge_test, 0, large)
	if res:
		return (loop_head, (eqs, 1), (const_ge (res), large))
	return None

def check_split_induct (p, restrs, hyps, split, tags = None):
	"""perform both the induction check and a function-call based check
	on successes which can avoid some problematic inductions."""
	((l_split, (_, l_step), _), (r_split, (_, r_step), _), _, n, _) = split
	if tags == None:
		tags = p.pairing.tags

	err_hyp = check.split_r_err_pc_hyp (p, split, restrs, tags = tags)
	hyps = [err_hyp] + hyps + check.split_loop_hyps (tags, split,
		restrs, exit = False)

	rep = mk_graph_slice (p)

	if not check.check_split_induct_step_group (rep, restrs, hyps, split,
			tags = tags):
		return False

	l_succs = get_n_offset_successes (rep, l_split, l_step, restrs)
	r_succs = get_n_offset_successes (rep, r_split, r_step, restrs)

	if not l_succs:
		return True

	hyp = syntax.foldr1 (syntax.mk_and, l_succs)
	if r_succs:
		hyp = syntax.mk_implies (foldr1 (syntax.mk_and, r_succs), hyp)

	return rep.test_hyp_whyps (hyp, hyps)

def init_loops_to_split (p, restrs):
	to_split = loops_to_split (p, restrs)

	return [n for n in to_split
		if not [n2 for n2 in to_split if n2 != n
			and p.is_reachable_from (n2, n)]]

def restr_others_both (p, restrs, n, m):
	extras = [(sp, vc_double_range (n, m))
		for sp in loops_to_split (p, restrs)]
	return restrs + tuple (extras)

def restr_others_as_necessary (p, n, restrs, init_bound, offs_bound,
		skip_loops = []):
	extras = [(sp, vc_double_range (init_bound, offs_bound))
		for sp in loops_to_split (p, restrs)
		if sp not in skip_loops
		if p.is_reachable_from (sp, n)]
	return restrs + tuple (extras)

def loop_no_match_unroll (rep, restrs, hyps, split, other_tag, unroll):
	p = rep.p
	assert p.node_tags[split][0] != other_tag
	restr = ((split, vc_num (unroll)), )
	restrs2 = restr_others (p, restr + restrs, 2)
	loop_cond = rep.get_pc ((split, restr + restrs))
	ret_cond = rep.get_pc (('Ret', restrs2), tag = other_tag)
	# loop should be reachable
	if rep.test_hyp_whyps (mk_not (loop_cond), hyps):
		trace ('Loop weak at %d (unroll count %d).' %
			(split, unroll))
		return True
	# reaching the loop should imply reaching a loop on the other side
	hyp = mk_not (mk_and (loop_cond, ret_cond))
	if not rep.test_hyp_whyps (hyp, hyps):
		trace ('Loop independent at %d (unroll count %d).' %
			(split, unroll))
		return True
	return False

def loop_no_match (rep, restrs, hyps, split, other_tag,
		check_speculate_ineq = False):
	if not loop_no_match_unroll (rep, restrs, hyps, split, other_tag, 4):
		return False
	if not loop_no_match_unroll (rep, restrs, hyps, split, other_tag, 8):
		return False
	if not check_speculate_ineq:
		return 'Restr'
	spec = get_new_rhs_speculate_ineq (rep.p, restrs, hyps, split)
	if not spec:
		return 'Restr'
	hyp = check.single_induct_resulting_hyp (rep.p, restrs, spec)
	hyps2 = hyps + [hyp]
	if not loop_no_match_unroll (rep, restrs, hyps2, split, other_tag, 8):
		return 'SingleRevInduct'
	return 'Restr'

last_searcher_results = []

def restr_point_name (p, n):
	if p.loop_id (n):
		return '%s (loop head)' % n
	elif p.loop_id (n):
		return '%s (in loop %d)' % (n, p.loop_id (n))
	else:
		return str (n)

def fail_searcher (p, restrs, hyps):
	return ('Fail Searcher', None)

def build_proof_rec (searcher, p, restrs, hyps, name = "problem"):
	trace ('doing build proof rec with restrs = %r, hyps = %r' % (restrs, hyps))
	if searcher == None:
		searcher = default_searcher

	(kind, details) = searcher (p, restrs, hyps)
	last_searcher_results.append ((p, restrs, hyps, kind, details, name))
	del last_searcher_results[:-10]
	if kind == 'Restr':
		(restr_kind, restr_points) = details
		printout ("Discovered that points [%s] can be bounded"
			% ', '.join ([restr_point_name (p, n)
				for n in restr_points]))
		printout ("  (in %s)" % name)
		restr_hints = [(n, restr_kind, True) for n in restr_points]
		return build_proof_rec_with_restrs (restr_hints,
			searcher, p, restrs, hyps, name = name)
	elif kind == 'Leaf':
		return ProofNode ('Leaf', None, ())
	assert kind in ['CaseSplit', 'Split', 'SingleRevInduct'], kind
	if kind == 'CaseSplit':
		(details, hints) = details
	probs = check.proof_subproblems (p, kind, details, restrs, hyps, name)
	if kind == 'CaseSplit':
		printout ("Decided to case split at %s" % str (details))
		printout ("  (in %s)" % name)
		restr_hints = [[(n, 'Number', False) for n in hints]
			for cases in [0, 1]]
	elif kind == 'SingleRevInduct':
		printout ('Found a future induction at %s' % str (details[0]))
		restr_hints = [[]]
	else:
		restr_points = check.split_heads (details)
		restr_hints = [[(n, rkind, True) for n in restr_points]
			for rkind in ['Number', 'Offset']]
		printout ("Discovered a loop relation for split points %s"
			% list (restr_points))
		printout ("  (in %s)" % name)
	subpfs = []
	for ((restrs, hyps, name), hints) in logic.azip (probs, restr_hints):
		printout ('Now doing proof search in %s.' % name)
		pf = build_proof_rec_with_restrs (hints, searcher,
			p, restrs, hyps, name = name)
		subpfs.append (pf)
	return ProofNode (kind, details, subpfs)

def build_proof_rec_with_restrs (split_hints, searcher, p, restrs,
		hyps, name = "problem"):
	if not split_hints:
		return build_proof_rec (searcher, p, restrs, hyps, name = name)

	(sp, kind, must_find) = split_hints[0]
	use_hyps = list (hyps)
	if p.node_tags[sp][0] != p.pairing.tags[1]:
		nrerr_hyp = check.non_r_err_pc_hyp (p.pairing.tags,
			restr_others (p, restrs, 2))
		use_hyps = use_hyps + [nrerr_hyp]

	if p.loop_id (sp):
		lim_pair = get_proof_split_limit (p, sp, restrs, use_hyps,
			kind, must_find = must_find)
	else:
		lim_pair = get_proof_visit_restr (p, sp, restrs, use_hyps,
			kind, must_find = must_find)

	if not lim_pair:
		assert not must_find
		return build_proof_rec_with_restrs (split_hints[1:],
			searcher, p, restrs, hyps, name = name)

	(min_v, max_v) = lim_pair
	if kind == 'Number':
		vc_opts = rep_graph.vc_options (range (min_v, max_v), [])
	else:
		vc_opts = rep_graph.vc_options ([], range (min_v, max_v))

	restrs = restrs + ((sp, vc_opts), )
	subproof = build_proof_rec_with_restrs (split_hints[1:],
		searcher, p, restrs, hyps, name = name)

	return ProofNode ('Restr', (sp, (kind, (min_v, max_v))), [subproof])

def get_proof_split_limit (p, sp, restrs, hyps, kind, must_find = False):
	limit = get_split_limit (p, sp, restrs, hyps, kind,
		must_find = must_find)
	if limit == None:
		return None
	# double-check this limit with a rep constructed without the 'fast' flag
	limit = find_split_limit (p, sp, restrs, hyps, kind,
		hints = [limit, limit + 1], use_rep = mk_graph_slice (p))
	return (0, limit + 1)

def get_proof_visit_restr (p, sp, restrs, hyps, kind, must_find = False):
	rep = rep_graph.mk_graph_slice (p)
	pc = rep.get_pc ((sp, restrs))
	if rep.test_hyp_whyps (pc, hyps):
		return (1, 2)
	elif rep.test_hyp_whyps (mk_not (pc), hyps):
		return (0, 1)
	else:
		assert not must_find
		return None

def default_searcher (p, restrs, hyps):
	# use any handy init splits
	res = init_proof_case_split (p, restrs, hyps)
	if res:
		return res

	# detect any un-split loops
	to_split_init = init_loops_to_split (p, restrs)
	rep = mk_graph_slice (p, fast = True)

	l_tag, r_tag = p.pairing.tags
	l_to_split = [n for n in to_split_init if p.node_tags[n][0] == l_tag]
	r_to_split = [n for n in to_split_init if p.node_tags[n][0] == r_tag]
	l_ep = p.get_entry (l_tag)
	r_ep = p.get_entry (r_tag)

	for r_sp in r_to_split:
		trace ('checking loop_no_match at %d' % r_sp, push = 1)
		res = loop_no_match (rep, restrs, hyps, r_sp, l_tag,
			check_speculate_ineq = True)
		if res == 'Restr':
			return ('Restr', ('Number', [r_sp]))
		elif res == 'SingleRevInduct':
			spec = get_new_rhs_speculate_ineq (p, restrs, hyps, r_sp)
			assert spec
			return ('SingleRevInduct', spec)
		trace (' .. done checking loop no match', push = -1)

	if l_to_split and not r_to_split:
		n = l_to_split[0]
		trace ('lhs loop alone, limit must be found.')
		return ('Restr', ('Number', [n]))

	if l_to_split:
		n = l_to_split[0]
		trace ('checking lhs loop_no_match at %d' % n, push = 1)
		if loop_no_match (rep, restrs, hyps, n, r_tag):
			trace ('loop does not match!', push = -1)
			return ('Restr', ('Number', [n]))
		trace (' .. done checking loop no match', push = -1)

		(kind, split) = find_split_loop (p, n, restrs, hyps)
		if kind == 'LoopUnroll':
			return ('Restr', ('Number', [split]))
		return (kind, split)

	if r_to_split:
		n = r_to_split[0]
		trace ('rhs loop alone, limit must be found.')
		return ('Restr', ('Number', [n]))

	return ('Leaf', None)

def use_split_searcher (p, split):
	xs = set ([p.loop_id (h) for h in check.split_heads (split)])
	def searcher (p, restrs, hyps):
		ys = set ([p.loop_id (h)
			for h in init_loops_to_split (p, restrs)])
		if xs <= ys:
			return ('Split', split)
		else:
			return default_searcher (p, restrs, hyps)
	return searcher

